﻿using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// A class for running TensorFlow operations. Owns the underlying native
    /// TF_Session handle (<see cref="_session"/>) and translates between managed
    /// feed/fetch values and the structures expected by the TF C API.
    /// </summary>
    public class BaseSession
    {
        protected Graph _graph;
        protected bool _opened;
        protected bool _closed;
        protected int _current_version;
        protected byte[] _target;
        // Native TF_Session* handle returned by TF_NewSession.
        protected IntPtr _session;

        /// <summary>
        /// Creates a new session bound to <paramref name="graph"/> (or the
        /// default graph when none is given) and the given execution target.
        /// </summary>
        /// <param name="target">Execution engine to connect to; empty for in-process.</param>
        /// <param name="graph">Graph to launch; defaults to <c>ops.get_default_graph()</c>.</param>
        public BaseSession(string target = "", Graph graph = null)
        {
            // Fall back to the ambient default graph when no graph is supplied.
            _graph = graph ?? ops.get_default_graph();

            _target = Encoding.UTF8.GetBytes(target);
            var opts = c_api.TF_NewSessionOptions();
            var status = new Status();
            _session = c_api.TF_NewSession(_graph, opts, status);

            c_api.TF_DeleteSessionOptions(opts);

            // Fail fast if the native session could not be created; previously
            // the status was created but never inspected, silently leaving
            // _session invalid on error.
            status.Check(true);
        }

        /// <summary>
        /// Runs operations and evaluates tensors in <paramref name="fetches"/>.
        /// </summary>
        /// <param name="fetches">A single graph element (tensor/operation) to evaluate.</param>
        /// <param name="feed_dict">Optional feeds mapping graph elements to values.</param>
        /// <returns>The fetched value as an <see cref="NDArray"/>.</returns>
        public virtual NDArray run(object fetches, params FeedItem[] feed_dict)
        {
            return _run(fetches, feed_dict);
        }

        /// <summary>
        /// Validates the feeds, resolves them to graph tensors, and performs the run.
        /// </summary>
        private NDArray _run(object fetches, FeedItem[] feed_dict = null)
        {
            var feed_dict_tensor = new Dictionary<object, object>();
            var feed_map = new Dictionary<object, object>();

            // Expands one FeedItem into (key, value) pairs. Kept as a hook for
            // future nested-feed support (mirrors the python implementation).
            Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) =>
            {
                return new (object, object)[] { (item.Key, item.Value) };
            };

            // Validate and process feed_dict.
            if (feed_dict != null)
            {
                foreach (var feed in feed_dict)
                {
                    foreach (var (subfeed, subfeed_val) in feed_fn(feed))
                    {
                        // Resolve the feed key to an actual tensor in this graph.
                        var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
                        var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();

                        // Normalize supported scalar/array values to NDArray
                        // (IntPtr and NDArray pass through untouched).
                        switch (subfeed_val)
                        {
                            case IntPtr pointer:
                                feed_dict_tensor[subfeed_t] = pointer;
                                break;
                            case NDArray nd:
                                feed_dict_tensor[subfeed_t] = nd;
                                break;
                            case float floatVal:
                                feed_dict_tensor[subfeed_t] = (NDArray)floatVal;
                                break;
                            case double doubleVal:
                                feed_dict_tensor[subfeed_t] = (NDArray)doubleVal;
                                break;
                            case int intVal:
                                feed_dict_tensor[subfeed_t] = (NDArray)intVal;
                                break;
                            case string str:
                                feed_dict_tensor[subfeed_t] = (NDArray)str;
                                break;
                            case byte[] bytes:
                                feed_dict_tensor[subfeed_t] = (NDArray)bytes;
                                break;
                            default:
                                // Note: message was previously an interpolated
                                // string with no holes; text is unchanged.
                                Console.WriteLine("can't handle data type of subfeed_val");
                                throw new NotImplementedException("_run subfeed");
                        }

                        feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
                    }
                }
            }

            // Create a fetch handler to take care of the structure of fetches.
            var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);

            // Run request and get response.
            // We need to keep the returned movers alive for the following _do_run().
            // These movers are no longer needed when _do_run() completes, and
            // are deleted when `movers` goes out of scope when this _run() ends.
            var _ = _update_with_movers();
            var final_fetches = fetch_handler.fetches();
            var final_targets = fetch_handler.targets();

            // We only want to really perform the run if fetches or targets are provided,
            // or if the call is a partial run that specifies feeds.
            var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor);

            return fetch_handler.build_results(this, results);
        }

        /// <summary>
        /// Runs a step based on the given fetches and feeds.
        /// </summary>
        /// <param name="target_list">A list of operations to be run, but not fetched.</param>
        /// <param name="fetch_list">Tensors whose values are to be returned.</param>
        /// <param name="feed_dict">Tensor-keyed map of input values for this step.</param>
        /// <returns>
        /// A list of numpy ndarrays, corresponding to the elements of
        /// `fetch_list`.  If the ith element of `fetch_list` contains the
        /// name of an operation, the first Tensor output of that operation
        /// will be returned for that element.
        /// </returns>
        private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
        {
            // Convert each (tensor, value) feed into a (TF_Output, Tensor) pair
            // the C API can consume, wrapping raw CLR values in new Tensors.
            var feeds = feed_dict.Select(x =>
            {
                if (x.Key is Tensor tensor)
                {
                    switch (x.Value)
                    {
                        case IntPtr pointer:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), pointer);
                        case Tensor t1:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1);
                        case NDArray nd:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd));
                        case int intVal:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal));
                        case float floatVal:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal));
                        case double doubleVal:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal));
                        default:
                            throw new NotImplementedException("feed_dict data type");
                    }
                }
                throw new NotImplementedException("_do_run.feed_dict");
            }).ToArray();
            var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();

            return _call_tf_sessionrun(feeds, fetches, target_list);
        }

        /// <summary>
        /// Invokes the native TF_SessionRun and converts each fetched native
        /// tensor handle into a managed <see cref="NDArray"/>.
        /// </summary>
        private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
        {
            // Ensure any changes to the graph are reflected in the runtime.
            _extend_graph();

            var status = new Status();

            // Pre-allocated slots that TF_SessionRun fills with TF_Tensor* handles.
            var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();

            c_api.TF_SessionRun(_session,
                run_options: null,
                inputs: feed_dict.Select(f => f.Key).ToArray(),
                input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
                ninputs: feed_dict.Length,
                outputs: fetch_list,
                output_values: output_values,
                noutputs: fetch_list.Length,
                target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
                ntargets: target_list.Count,
                run_metadata: IntPtr.Zero,
                status: status);

            // Throws if the native run reported an error.
            status.Check(true);

            var result = new NDArray[fetch_list.Length];

            for (int i = 0; i < fetch_list.Length; i++)
            {
                result[i] = fetchValue(output_values[i]);
            }

            return result;
        }

        /// <summary>
        /// Copies the data of one native output tensor into a managed NDArray.
        /// Only a subset of dtypes is supported; others throw.
        /// </summary>
        private unsafe NDArray fetchValue(IntPtr output)
        {
            var tensor = new Tensor(output);
            NDArray nd = null;
            Type type = tensor.dtype.as_numpy_datatype();
            var ndims = tensor.shape.Select(x => (int)x).ToArray();
            // Pointer to the raw tensor buffer in native memory.
            var offset = c_api.TF_TensorData(output);

            switch (tensor.dtype)
            {
                case TF_DataType.TF_STRING:
                    var bytes = tensor.Data();
                    // Weird, don't know why we have to start from offset 9,
                    // with the length at the beginning.
                    // NOTE(review): this looks like the TF_STRING scalar layout
                    // (8-byte offset table, then a length-prefixed string) —
                    // confirm against the TF C API string-tensor encoding.
                    var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
                    nd = np.array(str).reshape();
                    break;
                case TF_DataType.TF_INT16:
                    var shorts = new short[tensor.size];
                    for (ulong i = 0; i < tensor.size; i++)
                        shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i));
                    nd = np.array(shorts).reshape(ndims);
                    break;
                case TF_DataType.TF_INT32:
                    var ints = new int[tensor.size];
                    for (ulong i = 0; i < tensor.size; i++)
                        ints[i] = *(int*)(offset + (int)(tensor.itemsize * i));
                    nd = np.array(ints).reshape(ndims);
                    break;
                case TF_DataType.TF_FLOAT:
                    var floats = new float[tensor.size];
                    for (ulong i = 0; i < tensor.size; i++)
                        floats[i] = *(float*)(offset + (int)(tensor.itemsize * i));
                    nd = np.array(floats).reshape(ndims);
                    break;
                case TF_DataType.TF_DOUBLE:
                    var doubles = new double[tensor.size];
                    for (ulong i = 0; i < tensor.size; i++)
                        doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i));
                    nd = np.array(doubles).reshape(ndims);
                    break;
                default:
                    throw new NotImplementedException("can't fetch output");
            }

            return nd;
        }

        /// <summary>
        /// If a tensor handle that is fed to a device incompatible placeholder,
        /// we move the tensor to the right device, generate a new tensor handle,
        /// and update feed_dict to use the new handle.
        /// </summary>
        private List<object> _update_with_movers()
        {
            // Not implemented yet; callers only need a (possibly empty) list
            // of movers to keep alive for the duration of _do_run().
            return new List<object> { };
        }

        // Placeholder: graph extension is not required by the current C API path.
        private void _extend_graph()
        {

        }
    }
}
